// Copyright (c) The Move Contributors SPDX-License-Identifier: Apache-2.0

use crate::ast::{self, Exp};

use move_binary_format::normalized::Constant;
use move_core_types::runtime_value::MoveValue as Value;
use move_model_2::{model, source_kind::SourceKind, summary};
use move_stackless_bytecode_2::ast::{DataOp, PrimitiveOp};
use move_symbol_pool::Symbol;
use pretty_simple::{Doc, Doc as D, ToDoc};

use indexmap::IndexMap;

// -------------------------------------------------------------------------------------------------
// Render Context
// -------------------------------------------------------------------------------------------------

/// Rendering state shared by the expression printers in this module.
struct Context {
    /// Display names for module constants, keyed by the address of the constant's compiled
    /// definition (pointer identity, not structural equality).
    constant_table: IndexMap<*const Constant<Symbol>, String>,
}

impl Context {
    /// Looks up the display name registered for `c` by its address.
    ///
    /// Constants that were never registered render as an inline
    /// `/* unknown constant <address> */` comment instead of failing.
    fn get_constant(&self, c: &Constant<Symbol>) -> Doc {
        let address: *const Constant<Symbol> = c;
        if let Some(label) = self.constant_table.get(&address) {
            D::text(label)
        } else {
            D::text(format!("/* unknown constant {address:p} */"))
        }
    }
}

// -------------------------------------------------------------------------------------------------
// Render Functions
// -------------------------------------------------------------------------------------------------

/// Renders a decompiled module as Move source text.
///
/// The output starts with an auto-generated banner and a `module <pkg>::<name>;` header,
/// followed by `structs`, `enums`, `constants`, and `functions` sections. Each section is emitted
/// only when it has at least one entry, and every entry is followed by a blank line. Constants
/// receive synthetic names `C0`, `C1`, … in the order `compiled_constants()` yields them.
///
/// # Errors
/// Returns an error if a function in `module` has no counterpart in `model_mod`.
pub fn module<S: SourceKind>(
    model: &model::Model<S>,
    pkg_name: &str,
    model_mod: model::Module<'_, S>,
    module: &ast::Module,
) -> anyhow::Result<Doc> {
    // The model gives us no identity for constants that we could match against the constants
    // referenced from the bytecode: it only exposes the data held in their definition RC, plus
    // their module and raw bytes. As a workaround, each constant is keyed by the in-memory address
    // of that definition. This is not ideal, but it should be stable enough while the model is
    // alive, which is all rendering needs.
    let named_constants: IndexMap<
        *const move_model_2::normalized::Constant,
        (String, model::CompiledConstant<S>),
    > = model_mod
        .compiled_constants()
        .enumerate()
        .map(|(ndx, c)| (c.compiled() as *const _, (format!("C{ndx}"), c)))
        .collect();

    let context = Context {
        constant_table: named_constants
            .iter()
            .map(|(ptr, (label, _))| (*ptr, label.clone()))
            .collect(),
    };

    // Appends the blank line that separates top-level items.
    let spaced = |item: Doc| item.concat(D::line()).concat(D::line());

    let crate::ast::Module {
        name: mod_name,
        functions,
    } = module;

    let mut doc = spaced(D::text("// Auto-generated by Move decompiler"));
    doc = spaced(
        doc.concat(D::text("module"))
            .concat_space(D::text(pkg_name))
            .concat(D::text("::"))
            .concat(D::text(mod_name.as_str()))
            .concat(D::text(";")),
    );

    if model_mod.structs().next().is_some() {
        let body = model_mod
            .structs()
            .fold(D::nil(), |acc, s| acc.concat(spaced(s.to_doc())));
        doc = doc.concat(spaced(D::text("// -- structs -- "))).concat(body);
    }

    if model_mod.enums().next().is_some() {
        let body = model_mod
            .enums()
            .fold(D::nil(), |acc, e| acc.concat(spaced(e.to_doc())));
        doc = doc.concat(spaced(D::text("// -- enums -- "))).concat(body);
    }

    if !named_constants.is_empty() {
        let summary_context = summary::Context::new(model);
        let body = named_constants
            .values()
            .fold(D::nil(), |acc, (label, constant)| {
                let compiled = constant.compiled();
                let ty_doc =
                    summary::Type::from_normalized(&summary_context, &compiled.type_).to_doc();
                let decl = D::text(format!("const {label}"))
                    .concat(D::text(":"))
                    .concat_space(ty_doc)
                    .group()
                    .concat_space(D::text("="))
                    .concat_space(value(constant.value()))
                    .concat(D::text(";"));
                acc.concat(spaced(decl))
            });
        doc = doc.concat(spaced(D::text("// -- constants -- "))).concat(body);
    }

    if !functions.is_empty() {
        let mut body = D::nil();
        for (fun_name, fun) in functions {
            let model_fun = model_mod.maybe_function(*fun_name).ok_or_else(|| {
                anyhow::anyhow!("Function {} not found in module {}", fun_name, module.name)
            })?;
            body = body.concat(spaced(function(&context, &model_fun, fun)));
        }
        doc = doc.concat(spaced(D::text("// -- functions -- "))).concat(body);
    }

    Ok(doc)
}

/// Renders one function: the model's header for `model_fun`, followed by the decompiled body
/// `fun.code` inside `{ ... }`, nested four columns deeper than the header.
fn function<S: SourceKind>(
    context: &Context,
    model_fun: &model::Function<'_, S>,
    fun: &crate::ast::Function,
) -> Doc {
    // TODO: Docs, Attributes
    let signature =
        move_model_2::pretty_printer::fun_header(model_fun, /* use_param_names */ false);
    let body = Doc::nest(Doc::line().concat(exp(context, &fun.code)), 4);

    signature
        .concat_space(Doc::text("{"))
        .concat(body)
        .concat(Doc::line())
        .concat(Doc::text("}"))
}

fn exp(context: &Context, exp: &Exp) -> Doc {
    fn braces_block(body: Doc) -> Doc {
        D::braces(D::line().concat(body.indent(4)).concat(D::line()))
    }

    // Render a list of statements separated by lines.
    fn stmts<'a, I>(context: &Context, it: I) -> Doc
    where
        I: IntoIterator<Item = &'a Exp>,
    {
        D::intersperse(
            it.into_iter().map(|e| recur(context, e)),
            D::text(";").concat(D::line()),
        )
    }

    // Expression-ish printers --------------------------------------------

    fn recur(context: &Context, e: &Exp) -> Doc {
        match e {
            Exp::Break => D::text("break"),
            Exp::Continue => D::text("continue"),
            Exp::Return(es) => match &es[..] {
                [] => D::text("return"),
                [x] => D::text("return").concat_space(recur(context, x)),
                _ => D::text("return").concat_space(exp_list(context, es).parens()),
            },
            Exp::Assign(lhs, rhs) => match &lhs[..] {
                [] => recur(context, rhs),
                [x] => D::text(x)
                    .concat_space(D::text("="))
                    .concat_space(recur(context, rhs)),
                _ => {
                    let lhs =
                        D::intersperse(lhs.iter().map(D::text), D::text(",").concat(D::space()))
                            .parens()
                            .group();
                    lhs.concat_space(D::text("="))
                        .concat_space(recur(context, rhs))
                }
            },
            Exp::LetBind(lhs, rhs) => {
                let lhs_doc = match &lhs[..] {
                    [] => D::text("_"),
                    [x] => D::text(x),
                    _ => D::parens(D::intersperse(
                        lhs.iter().map(D::text),
                        D::text(",").concat(D::space()),
                    ))
                    .group(),
                };
                D::text("let")
                    .concat_space(lhs_doc)
                    .concat_space(D::text("="))
                    .concat_space(recur(context, rhs))
            }
            Exp::Call((m, f), args) => {
                D::text(format!("{m}::{f}")).concat(exp_list(context, args).parens())
            }
            Exp::Abort(e) => D::text("abort").concat_space(recur(context, e)),
            Exp::Borrow(mutable, e) => {
                if *mutable {
                    D::text("&mut").concat_space(recur(context, e))
                } else {
                    D::text("&").concat(recur(context, e))
                }
            }
            Exp::Value(v) => value(v),
            Exp::Variable(s) => D::text(s),
            Exp::Constant(c) => context.get_constant(c),
            Exp::Seq(vs) => {
                let final_semi = matches!(vs.last(), Some(Exp::LetBind(_, _) | Exp::Assign(_, _)));
                let mut stmts = stmts(context, vs);
                if final_semi {
                    stmts = stmts.concat(D::text(";"));
                }
                braces_block(stmts)
            }
            Exp::Loop(b) => D::text("loop").concat_space(e_block(context, b)),
            Exp::While(c, b) => while_doc(context, c, b),
            Exp::IfElse(c, t, e) => if_doc(context, c, t, e),
            Exp::Switch(subject, (mid, enum_), arms) => {
                let arms_doc = Doc::intersperse(
                    arms.iter().map(|(variant, body)| {
                        D::text(variant.as_str())
                            .concat_space(D::text("=>"))
                            .concat_space(e_block(context, body))
                    }),
                    D::text(",").concat(D::line()),
                );
                D::text(format!("switch {mid}::{enum_}"))
                    .concat_space(D::parens(recur(context, subject)))
                    .concat_space(braces_block(arms_doc))
            }
            Exp::Primitive { op, args } => primitive_op_doc(context, op, args),
            Exp::Data { op, args } => data_op_doc(context, op, args),
            Exp::Unpack((mod_, struct_), items, exp) => {
                let items_doc = fields(items);
                D::text(format!("{mod_}::{struct_}"))
                    .concat_space(items_doc)
                    .concat_space(D::text("="))
                    .concat_space(recur(context, exp))
            }
            Exp::VecUnpack(lhs, exp) => {
                if lhs.is_empty() {
                    D::text("std::vector::destroy_empty").concat(recur(context, exp).parens())
                } else {
                    D::text("/* UNSUPPORT OP: MULTIARG VEC UNPACK ON ")
                        .concat(recur(context, exp))
                        .concat(D::text(" */"))
                }
            }
            Exp::UnpackVariant(_unpack_kind, (mod_, enum_, variant), items, exp) => {
                let items_doc = fields(items);
                D::text(format!("{mod_}::{enum_}::{variant}"))
                    .concat_space(items_doc)
                    .concat_space(D::text("="))
                    .concat_space(recur(context, exp))
            }
        }
    }

    fn e_block(context: &Context, e: &Exp) -> Doc {
        // If it’s already a block/seq, print its statements;
        // otherwise, treat the single expression as a statement.
        match e {
            Exp::Seq(vs) => {
                let final_semi = matches!(vs.last(), Some(Exp::LetBind(_, _) | Exp::Assign(_, _)));
                let mut stmts = stmts(context, vs);
                if final_semi {
                    stmts = stmts.concat(D::text(";"));
                }
                braces_block(stmts)
            }
            other => {
                let final_semi = matches!(other, Exp::LetBind(_, _) | Exp::Assign(_, _));
                let mut body = recur(context, other);
                if final_semi {
                    body = body.concat(D::text(";"));
                }
                braces_block(body)
            }
        }
    }

    fn while_doc(context: &Context, cond: &Exp, body: &Exp) -> Doc {
        D::text("while")
            .concat_space(recur(context, cond).parens())
            .concat_space(e_block(context, body))
    }

    fn if_doc(context: &Context, cond: &Exp, then_b: &Exp, else_b: &Option<Exp>) -> Doc {
        let then_block = e_block(context, then_b);
        match else_b {
            None => D::text("if")
                .concat_space(recur(context, cond).parens())
                .concat_space(then_block),
            Some(e) => {
                let else_block = e_block(context, e);
                D::text("if")
                    .concat_space(recur(context, cond).parens())
                    .concat_space(then_block)
                    .concat_space(D::text("else"))
                    .concat_space(else_block)
            }
        }
    }
    recur(context, exp)
}

/// Renders each expression of `it` with `exp` and joins the results with `,` plus a
/// `D::space()` separator. No surrounding delimiters are added.
fn exp_list<'a, I>(context: &Context, it: I) -> Doc
where
    I: IntoIterator<Item = &'a Exp>,
{
    let comma = D::text(",").concat(D::space());
    let rendered = it.into_iter().map(|e| exp(context, e));
    D::intersperse(rendered, comma)
}

/// Renders `(name, text)` pairs as `{ name: text, ... }`, or `{}` when there are none.
///
/// NOTE(review): the second component was described as a type, but for unpack patterns it may
/// be the bound variable name. Confirm against the AST producer.
fn fields(fields: &[(Symbol, String)]) -> Doc {
    match fields {
        [] => D::nil().braces(),
        _ => {
            let entries = fields.iter().map(|(label, rendered)| {
                D::text(label.as_str())
                    .concat(D::text(":"))
                    .concat_space(D::text(rendered))
            });
            let joined = D::intersperse(entries, D::text(",").concat(D::space()));
            D::space().concat(joined).concat(D::space()).braces()
        }
    }
}

fn data_op_doc(context: &Context, op: &DataOp, args: &[Exp]) -> Doc {
    fn maybe_parens(context: &Context, e: &Exp) -> Doc {
        match e {
            Exp::Variable(_) | Exp::Value(_) | Exp::Constant(_) => exp(context, e),
            _ => exp(context, e).parens(),
        }
    }

    match op {
        DataOp::ReadRef => D::text("*").concat(maybe_parens(context, &args[0])),

        DataOp::WriteRef => D::text("*")
            .concat(maybe_parens(context, &args[0]))
            .concat_space(D::text("="))
            .concat_space(exp(context, &args[1])),
        DataOp::FreezeRef => D::text("freeze").concat(exp(context, &args[0]).parens()),

        DataOp::MutBorrowField(field_ref) => D::text("&mut ")
            .concat(maybe_parens(context, &args[0]))
            .concat(D::text("."))
            .concat(D::text(field_ref.field.name.as_str())),

        DataOp::ImmBorrowField(field_ref) => D::text("&")
            .concat(maybe_parens(context, &args[0]))
            .concat(D::text("."))
            .concat(D::text(field_ref.field.name.as_str())),

        DataOp::VecPack(_) => D::text("vec![")
            .concat(exp_list(context, args))
            .concat(D::text("]")),

        DataOp::VecLen(_) => exp(context, &args[0]).concat(D::text(".len()")),

        DataOp::VecImmBorrow(_) => D::text("&")
            .concat(maybe_parens(context, &args[0]))
            .concat(D::text("["))
            .concat(exp(context, &args[1]))
            .concat(D::text("]")),

        DataOp::VecMutBorrow(_) => D::text("&mut ")
            .concat(maybe_parens(context, &args[0]))
            .concat(D::text("["))
            .concat(exp(context, &args[1]))
            .concat(D::text("]")),

        DataOp::VecPushBack(_) => maybe_parens(context, &args[0])
            .concat(D::text(".push_back("))
            .concat(exp(context, &args[1]))
            .concat(D::text(")")),

        DataOp::VecPopBack(_) => maybe_parens(context, &args[0])
            .concat(D::text(".pop_back("))
            .concat(exp(context, &args[1]))
            .concat(D::text(")")),

        DataOp::VecSwap(_) => maybe_parens(context, &args[0])
            .concat(D::text(".swap("))
            .concat(exp(context, &args[1]))
            .concat(D::text(", "))
            .concat(exp(context, &args[2]))
            .concat(D::text(")")),

        DataOp::PackVariant(variant) => {
            let fields = &variant.variant.fields.0;
            assert!(fields.len() == args.len());
            let enum_name = variant.enum_.name;
            let variant_name = variant.variant.name;
            D::text(format!("{enum_name}::{variant_name}")).concat_space(if fields.is_empty() {
                D::nil().braces()
            } else {
                D::space()
                    .concat(D::intersperse(
                        fields.iter().zip(args.iter()).map(|((name, _ty), e)| {
                            D::text(name.as_str())
                                .concat(D::text(":"))
                                .concat_space(exp(context, e))
                        }),
                        D::text(",").concat(D::space()),
                    ))
                    .concat(D::space())
                    .braces()
            })
        }

        DataOp::Pack(struct_) => {
            let fields = &struct_.struct_.fields.0;
            assert!(fields.len() == args.len());
            let struct_name = struct_.struct_.name;
            D::text(format!("{struct_name}")).concat_space(if fields.is_empty() {
                D::nil().braces()
            } else {
                D::space()
                    .concat(D::intersperse(
                        fields.iter().zip(args.iter()).map(|((name, _ty), e)| {
                            D::text(name.as_str())
                                .concat(D::text(":"))
                                .concat_space(exp(context, e))
                        }),
                        D::text(",").concat(D::space()),
                    ))
                    .concat(D::space())
                    .braces()
            })
        }

        DataOp::Unpack(_) => unreachable!("Unpack"),
        DataOp::VecUnpack(_) => unreachable!("VecUnpack"),

        DataOp::UnpackVariant(_)
        | DataOp::UnpackVariantImmRef(_)
        | DataOp::UnpackVariantMutRef(_) => unreachable!("Unpack variant"),
    }
}

fn primitive_op_doc(context: &Context, op: &PrimitiveOp, args: &[Exp]) -> Doc {
    let bin = |lhs: &Exp, sym: &str, rhs: &Exp| {
        exp(context, lhs)
            .concat_space(D::text(sym.to_string()))
            .concat_space(exp(context, rhs))
    };

    match op {
        PrimitiveOp::CastU8 => exp(context, &args[0]).concat(D::text("as u8")),
        PrimitiveOp::CastU16 => exp(context, &args[0]).concat(D::text("as u16")),
        PrimitiveOp::CastU32 => exp(context, &args[0]).concat(D::text("as u32")),
        PrimitiveOp::CastU64 => exp(context, &args[0]).concat(D::text("as u64")),
        PrimitiveOp::CastU128 => exp(context, &args[0]).concat(D::text("as u128")),
        PrimitiveOp::CastU256 => exp(context, &args[0]).concat(D::text("as u256")),

        PrimitiveOp::Add => bin(&args[0], "+", &args[1]),
        PrimitiveOp::Subtract => bin(&args[0], "-", &args[1]),
        PrimitiveOp::Multiply => bin(&args[0], "*", &args[1]),
        PrimitiveOp::Modulo => bin(&args[0], "%", &args[1]),
        PrimitiveOp::Divide => bin(&args[0], "/", &args[1]),
        PrimitiveOp::BitOr => bin(&args[0], "|", &args[1]),
        PrimitiveOp::BitAnd => bin(&args[0], "&", &args[1]),
        PrimitiveOp::Xor => bin(&args[0], "^", &args[1]),
        PrimitiveOp::Or => bin(&args[0], "||", &args[1]),
        PrimitiveOp::And => bin(&args[0], "&&", &args[1]),
        PrimitiveOp::Equal => bin(&args[0], "==", &args[1]),
        PrimitiveOp::NotEqual => bin(&args[0], "!=", &args[1]),
        PrimitiveOp::LessThan => bin(&args[0], "<", &args[1]),
        PrimitiveOp::GreaterThan => bin(&args[0], ">", &args[1]),
        PrimitiveOp::LessThanOrEqual => bin(&args[0], "<=", &args[1]),
        PrimitiveOp::GreaterThanOrEqual => bin(&args[0], ">=", &args[1]),

        PrimitiveOp::Not => D::text("!").concat(exp(context, &args[0]).parens()),

        PrimitiveOp::ShiftLeft => bin(&args[0], "<<", &args[1]),
        PrimitiveOp::ShiftRight => bin(&args[0], ">>", &args[1]),
    }
}

/// Renders a literal Move value, as found in constants and `Exp::Value`, in Move syntax.
///
/// Integers carry an explicit type suffix (e.g. `5u64`). Addresses are printed as `@0x…` hex
/// literals, and vectors use Move's `vector[...]` literal syntax.
///
/// # Panics
/// Panics on struct, signer, and variant values, which are not expected as literals here.
fn value(v: &Value) -> Doc {
    match v {
        Value::Bool(b) => D::text(b.to_string()),
        Value::U8(u) => D::text(u.to_string()).concat(D::text("u8")),
        Value::U16(u) => D::text(u.to_string()).concat(D::text("u16")),
        Value::U32(u) => D::text(u.to_string()).concat(D::text("u32")),
        Value::U64(u) => D::text(u.to_string()).concat(D::text("u64")),
        Value::U128(u) => D::text(u.to_string()).concat(D::text("u128")),
        Value::U256(u) => D::text(u.to_string()).concat(D::text("u256")),
        // The hex formatting traits on the address emit only the digits (the `0x` prefix appears
        // only in alternate `#` mode), so the prefix is written explicitly. Without it the
        // literal is not a valid Move address.
        Value::Address(a) => D::text(format!("@0x{a:x}")),
        Value::Vector(values) => D::text("vector[")
            .concat(D::intersperse(
                values.iter().map(value),
                D::text(",").concat(D::space()),
            ))
            .concat(D::text("]")),
        Value::Struct(_) | Value::Signer(_) | Value::Variant(_) => {
            unreachable!("struct, signer, and variant values cannot be rendered as literals")
        }
    }
}
